{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of AlexNet [1]\n",
    "## Introduction\n",
    "In this task, we use AlexNet [1] to perform image classification on the CIFAR-10 dataset [2].\n",
    "1. We fine-tune a pretrained AlexNet for 1 epoch on the CIFAR-10 training set.\n",
    "2. We evaluate the fine-tuned AlexNet on the CIFAR-10 test set and report its top-1 and top-5 accuracy.\n",
    "\n",
    "All operations are run on the CPU.\n",
    "\n",
    "[1] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. \"ImageNet classification with deep convolutional neural networks.\" Advances in neural information processing systems. 2012.\\\n",
    "[2] Krizhevsky, Alex, and Geoffrey Hinton. Learning multiple layers of features from tiny images. Vol. 1. No. 4. Technical report, University of Toronto, 2009."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as data\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pyplot import imshow, imsave"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load training set of CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# CIFAR-10 has 10 classes.\n",
    "num_classes = 10\n",
    "\n",
    "# Training-time augmentation: random 32x32 crops of the image zero-padded by 4 pixels,\n",
    "# random horizontal flips, then per-channel normalization.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(size=32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),\n",
    "                         std=(0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "trainset = datasets.CIFAR10(root='./data', train=True, download=True,\n",
    "                            transform=transform_train)\n",
    "trainloader = data.DataLoader(trainset, batch_size=128, shuffle=True,\n",
    "                              num_workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build AlexNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total params: 2.47M\n"
     ]
    }
   ],
   "source": [
    "class AlexNet(nn.Module):\n",
    "    \"\"\"AlexNet-style CNN adapted to 32x32 CIFAR images.\n",
    "\n",
    "    Five conv layers (each followed by ReLU, with three 2x2 max-pools) reduce a\n",
    "    3x32x32 input to a 256x1x1 feature map, which a single linear layer maps to\n",
    "    class logits. Do not reorder ``features``: the pretrained checkpoint's\n",
    "    state_dict keys (e.g. ``features.0.weight``) refer to these indices.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "        )\n",
    "        self.classifier = nn.Linear(256, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Flatten every dimension after the batch dimension before the linear layer.\n",
    "        flat = torch.flatten(self.features(x), start_dim=1)\n",
    "        return self.classifier(flat)\n",
    "\n",
    "\n",
    "model = AlexNet(num_classes=num_classes)\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(f'Total params: {n_params / 1e6:.2f}M')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build loss & optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy on the model's raw logits; SGD with momentum and L2 weight decay.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(\n",
    "    model.parameters(),\n",
    "    lr=0.001,\n",
    "    momentum=0.9,\n",
    "    weight_decay=5e-4,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load pretrained AlexNet\n",
    "Training the full network from scratch requires a GPU and a lot of time, so we fine-tune the network starting from the provided pretrained\n",
    "model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# map_location='cpu' lets a checkpoint containing GPU tensors load on a CPU-only machine\n",
    "# (this notebook runs everything on the CPU).\n",
    "checkpoint = torch.load('pretrained/pretrained_alexnet.pth.tar', map_location='cpu')\n",
    "\n",
    "# Weights saved from an nn.DataParallel-wrapped model carry a leading 'module.' prefix;\n",
    "# strip only that leading prefix so the keys match this single-device model.\n",
    "prefix = 'module.'\n",
    "state_dict = {\n",
    "    (k[len(prefix):] if k.startswith(prefix) else k): v\n",
    "    for k, v in checkpoint['state_dict'].items()\n",
    "}\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define AverageMeter & accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageMeter(object):\n",
    "    \"\"\"Computes and stores the average and current value.\n",
    "\n",
    "    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262\n",
    "\n",
    "    Attributes:\n",
    "        val: the most recent value passed to ``update``.\n",
    "        sum: running sum of ``val * n`` over all updates.\n",
    "        count: total weight ``n`` seen so far.\n",
    "        avg: ``sum / count``, the weighted running mean.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Zero all statistics.\"\"\"\n",
    "        self.val = 0\n",
    "        self.avg = 0\n",
    "        self.sum = 0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        \"\"\"Record ``val`` with weight ``n`` (e.g. a per-batch mean and the batch size).\"\"\"\n",
    "        self.val = val\n",
    "        self.sum += val * n\n",
    "        self.count += n\n",
    "        self.avg = self.sum / self.count\n",
    "\n",
    "\n",
    "def accuracy(output, target, topk=(1,)):\n",
    "    \"\"\"Compute top-k accuracy (in percent) for each k in ``topk``.\n",
    "\n",
    "    Args:\n",
    "        output: (batch, num_classes) tensor of class scores.\n",
    "        target: (batch,) tensor of ground-truth class indices.\n",
    "        topk: the k values to evaluate; max(topk) must not exceed num_classes.\n",
    "\n",
    "    Returns:\n",
    "        A list with one 0-dim tensor per k: the percentage of samples whose\n",
    "        target is among the k highest-scoring classes.\n",
    "    \"\"\"\n",
    "    maxk = max(topk)\n",
    "    batch_size = target.size(0)\n",
    "\n",
    "    # pred: (maxk, batch) class indices, highest score first.\n",
    "    _, pred = output.topk(maxk, 1, True, True)\n",
    "    pred = pred.t()\n",
    "    correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
    "\n",
    "    res = []\n",
    "    for k in topk:\n",
    "        # reshape, not view: `correct` derives from a transposed tensor and can be\n",
    "        # non-contiguous, in which case .view(-1) raises a RuntimeError.\n",
    "        correct_k = correct[:k].reshape(-1).float().sum(0)\n",
    "        res.append(correct_k.mul_(100.0 / batch_size))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune for 1 epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10/391) Data: 0.22s | Batch: 3.10s | Total: 31.03s | ETA: 1182.06s | Loss: 0.15 | top1:  94.22 | top5:  100.00\n",
      "(20/391) Data: 0.15s | Batch: 2.85s | Total: 56.93s | ETA: 1055.98s | Loss: 0.15 | top1:  94.61 | top5:  100.00\n",
      "(30/391) Data: 0.12s | Batch: 2.74s | Total: 82.32s | ETA: 990.64s | Loss: 0.15 | top1:  94.64 | top5:  100.00\n",
      "(40/391) Data: 0.10s | Batch: 2.67s | Total: 106.92s | ETA: 938.26s | Loss: 0.16 | top1:  94.51 | top5:  99.98\n",
      "(50/391) Data: 0.09s | Batch: 2.68s | Total: 133.82s | ETA: 912.67s | Loss: 0.16 | top1:  94.39 | top5:  99.95\n",
      "(60/391) Data: 0.09s | Batch: 2.63s | Total: 157.72s | ETA: 870.11s | Loss: 0.16 | top1:  94.38 | top5:  99.95\n",
      "(70/391) Data: 0.09s | Batch: 2.61s | Total: 182.62s | ETA: 837.46s | Loss: 0.16 | top1:  94.24 | top5:  99.96\n",
      "(80/391) Data: 0.08s | Batch: 2.60s | Total: 207.82s | ETA: 807.91s | Loss: 0.16 | top1:  94.25 | top5:  99.96\n",
      "(90/391) Data: 0.08s | Batch: 2.58s | Total: 232.42s | ETA: 777.32s | Loss: 0.16 | top1:  94.28 | top5:  99.96\n",
      "(100/391) Data: 0.08s | Batch: 2.58s | Total: 257.92s | ETA: 750.54s | Loss: 0.16 | top1:  94.25 | top5:  99.96\n",
      "(110/391) Data: 0.08s | Batch: 2.57s | Total: 282.32s | ETA: 721.20s | Loss: 0.16 | top1:  94.30 | top5:  99.96\n",
      "(120/391) Data: 0.08s | Batch: 2.57s | Total: 308.02s | ETA: 695.62s | Loss: 0.16 | top1:  94.27 | top5:  99.97\n",
      "(130/391) Data: 0.08s | Batch: 2.56s | Total: 332.62s | ETA: 667.80s | Loss: 0.16 | top1:  94.27 | top5:  99.96\n",
      "(140/391) Data: 0.08s | Batch: 2.55s | Total: 356.82s | ETA: 639.73s | Loss: 0.16 | top1:  94.25 | top5:  99.97\n",
      "(150/391) Data: 0.07s | Batch: 2.55s | Total: 382.12s | ETA: 613.94s | Loss: 0.16 | top1:  94.29 | top5:  99.97\n",
      "(160/391) Data: 0.07s | Batch: 2.53s | Total: 405.52s | ETA: 585.47s | Loss: 0.16 | top1:  94.29 | top5:  99.97\n",
      "(170/391) Data: 0.07s | Batch: 2.53s | Total: 429.82s | ETA: 558.77s | Loss: 0.16 | top1:  94.38 | top5:  99.97\n",
      "(180/391) Data: 0.07s | Batch: 2.53s | Total: 456.12s | ETA: 534.67s | Loss: 0.16 | top1:  94.38 | top5:  99.97\n",
      "(190/391) Data: 0.07s | Batch: 2.53s | Total: 480.62s | ETA: 508.44s | Loss: 0.16 | top1:  94.39 | top5:  99.97\n",
      "(200/391) Data: 0.07s | Batch: 2.53s | Total: 505.22s | ETA: 482.48s | Loss: 0.16 | top1:  94.41 | top5:  99.97\n",
      "(210/391) Data: 0.07s | Batch: 2.52s | Total: 529.12s | ETA: 456.05s | Loss: 0.16 | top1:  94.41 | top5:  99.97\n",
      "(220/391) Data: 0.07s | Batch: 2.52s | Total: 554.92s | ETA: 431.32s | Loss: 0.16 | top1:  94.36 | top5:  99.97\n",
      "(230/391) Data: 0.07s | Batch: 2.52s | Total: 580.22s | ETA: 406.15s | Loss: 0.16 | top1:  94.39 | top5:  99.97\n",
      "(240/391) Data: 0.07s | Batch: 2.53s | Total: 607.52s | ETA: 382.23s | Loss: 0.16 | top1:  94.39 | top5:  99.97\n",
      "(250/391) Data: 0.07s | Batch: 2.53s | Total: 633.02s | ETA: 357.02s | Loss: 0.16 | top1:  94.38 | top5:  99.97\n",
      "(260/391) Data: 0.07s | Batch: 2.53s | Total: 658.21s | ETA: 331.64s | Loss: 0.16 | top1:  94.36 | top5:  99.97\n",
      "(270/391) Data: 0.07s | Batch: 2.53s | Total: 683.01s | ETA: 306.09s | Loss: 0.16 | top1:  94.35 | top5:  99.97\n",
      "(280/391) Data: 0.07s | Batch: 2.53s | Total: 708.31s | ETA: 280.80s | Loss: 0.16 | top1:  94.36 | top5:  99.97\n",
      "(290/391) Data: 0.07s | Batch: 2.53s | Total: 732.31s | ETA: 255.05s | Loss: 0.16 | top1:  94.34 | top5:  99.98\n",
      "(300/391) Data: 0.07s | Batch: 2.53s | Total: 757.62s | ETA: 229.81s | Loss: 0.16 | top1:  94.33 | top5:  99.98\n",
      "(310/391) Data: 0.07s | Batch: 2.52s | Total: 782.02s | ETA: 204.33s | Loss: 0.17 | top1:  94.33 | top5:  99.98\n",
      "(320/391) Data: 0.07s | Batch: 2.52s | Total: 807.02s | ETA: 179.06s | Loss: 0.16 | top1:  94.34 | top5:  99.98\n",
      "(330/391) Data: 0.07s | Batch: 2.52s | Total: 832.12s | ETA: 153.82s | Loss: 0.17 | top1:  94.36 | top5:  99.98\n",
      "(340/391) Data: 0.07s | Batch: 2.52s | Total: 855.62s | ETA: 128.34s | Loss: 0.17 | top1:  94.37 | top5:  99.98\n",
      "(350/391) Data: 0.07s | Batch: 2.52s | Total: 882.82s | ETA: 103.42s | Loss: 0.17 | top1:  94.34 | top5:  99.98\n",
      "(360/391) Data: 0.07s | Batch: 2.52s | Total: 906.62s | ETA: 78.07s | Loss: 0.17 | top1:  94.32 | top5:  99.98\n",
      "(370/391) Data: 0.07s | Batch: 2.51s | Total: 930.21s | ETA: 52.80s | Loss: 0.17 | top1:  94.29 | top5:  99.97\n",
      "(380/391) Data: 0.07s | Batch: 2.51s | Total: 955.01s | ETA: 27.65s | Loss: 0.17 | top1:  94.29 | top5:  99.97\n",
      "(390/391) Data: 0.07s | Batch: 2.51s | Total: 977.92s | ETA: 2.51s | Loss: 0.17 | top1:  94.30 | top5:  99.97\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "\n",
    "batch_time = AverageMeter()\n",
    "data_time = AverageMeter()\n",
    "losses = AverageMeter()\n",
    "top1 = AverageMeter()\n",
    "top5 = AverageMeter()\n",
    "end = time.time()\n",
    "\n",
    "for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
    "    # measure data loading time\n",
    "    data_time.update(time.time() - end)\n",
    "\n",
    "    # compute output (plain tensors already track gradients; the old\n",
    "    # torch.autograd.Variable wrapper is a deprecated no-op)\n",
    "    outputs = model(inputs)\n",
    "    loss = criterion(outputs, targets)\n",
    "\n",
    "    # measure accuracy and record loss; detach() keeps the metric computation\n",
    "    # out of the autograd graph\n",
    "    prec1, prec5 = accuracy(outputs.detach(), targets, topk=(1, 5))\n",
    "    losses.update(loss.item(), inputs.size(0))\n",
    "    top1.update(prec1.item(), inputs.size(0))\n",
    "    top5.update(prec5.item(), inputs.size(0))\n",
    "\n",
    "    # compute gradient and do SGD step\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # measure elapsed time\n",
    "    batch_time.update(time.time() - end)\n",
    "    end = time.time()\n",
    "\n",
    "    # print progress every 10 batches and after the final (possibly partial) batch\n",
    "    if (batch_idx + 1) % 10 == 0 or batch_idx + 1 == len(trainloader):\n",
    "        log = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.2f}s | Total: {total:.2f}s | ETA: {eta:.2f}s | Loss: {loss:.2f} | top1: {top1: .2f} | top5: {top5: .2f}'.format(\n",
    "            batch=batch_idx + 1,\n",
    "            size=len(trainloader),\n",
    "            data=data_time.avg,\n",
    "            bt=batch_time.avg,\n",
    "            total=batch_time.sum,\n",
    "            eta=batch_time.avg*(len(trainloader)-batch_idx-1),\n",
    "            loss=losses.avg,\n",
    "            top1=top1.avg,\n",
    "            top5=top5.avg)\n",
    "        print(log)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save finetuned model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar'):\n",
    "    \"\"\"Save ``state`` with torch.save to ``<checkpoint>/<filename>``, creating the directory if needed.\"\"\"\n",
    "    # exist_ok=True avoids the check-then-create race of a separate isdir() test.\n",
    "    os.makedirs(checkpoint, exist_ok=True)\n",
    "    filepath = os.path.join(checkpoint, filename)\n",
    "    torch.save(state, filepath)\n",
    "\n",
    "\n",
    "save_checkpoint(\n",
    "    {'model_name': 'finetuned_alexnet', 'state_dict': model.state_dict()},\n",
    "    checkpoint='checkpoint/alexnet/',\n",
    "    filename='finetuned_alexnet.pth.tar')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment the two lines below to reload the fine-tuned weights saved by\n",
    "# the previous cell (e.g. after a kernel restart) instead of re-running fine-tuning.\n",
    "# checkpoint = torch.load('checkpoint/alexnet/finetuned_alexnet.pth.tar')\n",
    "# model.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load testing set of CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation uses no augmentation: only tensor conversion and the same\n",
    "# per-channel normalization constants as the training transform.\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),\n",
    "                         std=(0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "testset = datasets.CIFAR10(root='./data', train=False, download=False,\n",
    "                           transform=transform_test)\n",
    "testloader = data.DataLoader(testset, batch_size=100, shuffle=False,\n",
    "                             num_workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate finetuned AlexNet on CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10/100) Data: 0.11s | Batch: 0.64s | Total: 6.44s | ETA: 57.96s | Loss: 0.73 | top1:  77.40 | top5:  98.80\n",
      "(20/100) Data: 0.08s | Batch: 0.61s | Total: 12.24s | ETA: 48.95s | Loss: 0.77 | top1:  77.60 | top5:  98.60\n",
      "(30/100) Data: 0.07s | Batch: 0.60s | Total: 17.93s | ETA: 41.85s | Loss: 0.81 | top1:  77.33 | top5:  98.33\n",
      "(40/100) Data: 0.07s | Batch: 0.59s | Total: 23.43s | ETA: 35.15s | Loss: 0.82 | top1:  77.12 | top5:  98.28\n",
      "(50/100) Data: 0.06s | Batch: 0.58s | Total: 29.14s | ETA: 29.14s | Loss: 0.83 | top1:  77.12 | top5:  98.22\n",
      "(60/100) Data: 0.06s | Batch: 0.57s | Total: 34.34s | ETA: 22.89s | Loss: 0.82 | top1:  77.03 | top5:  98.30\n",
      "(70/100) Data: 0.06s | Batch: 0.57s | Total: 39.64s | ETA: 16.99s | Loss: 0.82 | top1:  76.87 | top5:  98.31\n",
      "(80/100) Data: 0.06s | Batch: 0.56s | Total: 45.04s | ETA: 11.26s | Loss: 0.83 | top1:  76.76 | top5:  98.33\n",
      "(90/100) Data: 0.05s | Batch: 0.56s | Total: 50.84s | ETA: 5.65s | Loss: 0.82 | top1:  76.88 | top5:  98.34\n",
      "(100/100) Data: 0.05s | Batch: 0.56s | Total: 55.64s | ETA: 0.00s | Loss: 0.82 | top1:  76.88 | top5:  98.40\n",
      "Final accuracy: 76.88 (top1), 98.40 (top5)\n"
     ]
    }
   ],
   "source": [
    "# switch to evaluate mode\n",
    "model.eval()\n",
    "\n",
    "batch_time = AverageMeter()\n",
    "data_time = AverageMeter()\n",
    "losses = AverageMeter()\n",
    "top1 = AverageMeter()\n",
    "top5 = AverageMeter()\n",
    "end = time.time()\n",
    "\n",
    "for batch_idx, (inputs, targets) in enumerate(testloader):\n",
    "    # measure data loading time\n",
    "    data_time.update(time.time() - end)\n",
    "\n",
    "    # Evaluation needs no gradients: keep both the forward pass and the loss inside\n",
    "    # no_grad (the old torch.autograd.Variable wrapper is a deprecated no-op).\n",
    "    with torch.no_grad():\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "\n",
    "    # measure accuracy and record loss\n",
    "    prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))\n",
    "    losses.update(loss.item(), inputs.size(0))\n",
    "    top1.update(prec1.item(), inputs.size(0))\n",
    "    top5.update(prec5.item(), inputs.size(0))\n",
    "\n",
    "    # measure elapsed time\n",
    "    batch_time.update(time.time() - end)\n",
    "    end = time.time()\n",
    "\n",
    "    # print progress every 10 batches and after the final (possibly partial) batch\n",
    "    if (batch_idx + 1) % 10 == 0 or batch_idx + 1 == len(testloader):\n",
    "        log = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.2f}s | Total: {total:.2f}s | ETA: {eta:.2f}s | Loss: {loss:.2f} | top1: {top1: .2f} | top5: {top5: .2f}'.format(\n",
    "            batch=batch_idx + 1,\n",
    "            size=len(testloader),\n",
    "            data=data_time.avg,\n",
    "            bt=batch_time.avg,\n",
    "            total=batch_time.sum,\n",
    "            eta=batch_time.avg*(len(testloader)-batch_idx-1),\n",
    "            loss=losses.avg,\n",
    "            top1=top1.avg,\n",
    "            top5=top5.avg)\n",
    "        print(log)\n",
    "\n",
    "print('Final accuracy: %.2f (top1), %.2f (top5)' % (top1.avg, top5.avg))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 class names, indexed by label id (categories[i] is the name of class i).\n",
    "categories = 'airplane automobile bird cat deer dog frog horse ship truck'.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rgb_img(idx, dataset, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):\n",
    "    \"\"\"Return sample ``idx`` of ``dataset`` as an HxWx3 uint8 RGB image for display.\n",
    "\n",
    "    Undoes ``transforms.Normalize(mean, std)`` (the defaults are the CIFAR-10\n",
    "    normalization constants used by this notebook's transforms) and rescales\n",
    "    [0, 1] back to [0, 255]. Values are rounded and clipped, so float round-off\n",
    "    can neither truncate a pixel down by one nor wrap a slightly negative value\n",
    "    when cast to uint8.\n",
    "    \"\"\"\n",
    "    img = dataset[idx][0].numpy().transpose(1, 2, 0)  # CHW -> HWC\n",
    "    img = img * np.asarray(std) + np.asarray(mean)\n",
    "    return np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categoriy: airplane, Possibility: 0.99\n",
      "Categoriy: bird, Possibility: 0.00\n",
      "Categoriy: cat, Possibility: 0.00\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZ3ElEQVR4nO2dXWxc13HH/3N3+SGSEiVKlERLlinbCho7H3LAGikSBG6CpG4QwAlQGMlD4AcjCooYaIAUqOECjQv0IWmbBHko0iq1EadI47hx0hiF0cY1AhgBCie0asuOZDu2Q0WkKH6IEknxY8ndO33Yq4IyzgzJy927is7/BwhanuG5d/bsnb3L89+ZEVUFIeT6J2m1A4SQYmCwExIJDHZCIoHBTkgkMNgJiQQGOyGRUN7KZBG5G8A3AZQA/LOqfsX7/c6uHu3euXsrp2wpIhIed2fZ0qY/z7c2eppNXml282vl2Twv1F3jHP67U2wvc06DKX97BzSON39pBssLl4PW3MEuIiUA/wDgowBGAfxSRJ5S1VPWnO6du3H3/X+R52SbnpI4cyTJ94HGOmab415JU9sPZ17i+Si2LSlZV4h95fhftXDmebMMHzucSW1ORKw6i1WVVdNWRjVsSJ3nlXrXm732Vee5aWIfs1arGeP2tWMt1b//09+ZU7byMf5OAG+o6luqugLgcQD3bOF4hJAmspVgPwDg7JqfR7MxQsg1SNM36ETkmIgMi8jw8sLlZp+OEGKwlWAfA3Djmp8PZmNXoarHVXVIVYc6u3u2cDpCyFbYSrD/EsARETksIu0APg3gqca4RQhpNLl341W1KiIPAPgv1KW3R1X1V+vNk1L4lJ5EYkleeSnl3I23ts8TZ6c1cXbO85wLcKQaOJvM7ta/44ZjFLH9aDdez7J3yXlr5ZxLtd22pW1hP0rGLj2A9jZ7d397l+3/zt4+01ZFybT9ZvR8cHyxYk6BJpYf9uu8JZ1dVZ8G8PRWjkEIKQZ+g46QSGCwExIJDHZCIoHBTkgkMNgJiYQt7cbnwVSAciRjeAkcbgaVk5zizbQsqeNIbtXQOWYttf13VEB7jiNFinNAcfyAhqWtqvfKOOdKE1sO8xJXamn4Em83E4aAnnb7eIMDvaatv3+faRs5O2naUF0Jj6v9unhXsAXv7IREAoOdkEhgsBMSCQx2QiKBwU5IJBS6Gy8iKBnJDjUNl+bJZhrHs2d4ZakajTqnquWu4ebgPLc8u7SuAuE8ucS5V6SGzRoHADjXQK1yyfHDvozbSuG06p4Oe87BgX7T1rfTTtOemblo2s6eGzdtVlkqt/yYabHhnZ2QSGCwExIJDHZCIoHBTkgkMNgJiQQGOyGRUHwijJHs4Ek8Fm7Xl5zSm1ffzcJpLoL87ZPyYr1/O3Kd13jEqf3mt2syLi2xLzmpLZm2ixMjpq2zZHvSf+j24Pitg4P2nD472WVledG0vTU6bdouLdmyYtVcE0d6M0w5Sw0SQq4nGOyERAKDnZBIYLATEgkMdkIigcFOSCRsSXoTkREA8wBqAKqqOrSBOcHxPFlqnrzW6JZRPl5roqKlN+N8jhuu6pm3zl8SbrsEtdsgtTk5e7u6jeMBWF2YMW39veHz7e+35TVV+1znJuZN29iMLR2uoMO0iYSft7e+SQ5JtxE6+x+qqi0wEkKuCfgxnpBI2GqwK4CfisgLInKsEQ4RQprDVj/Gf1BVx0RkL4BnRORVVX1u7S9kbwLHAKC7d/cWT0cIycuW7uyqOpb9PwngxwDuDPzOcVUdUtWhzm67pA8hpLnkDnYR6RaR7VceA/gYgFca5RghpLFs5WP8PgA/ziSuMoB/VdX/zHswVygrVEbbPOpqV8X6bsk4XmslN+PQsalzr6gZylCCijkngd3iad/eG0zb7KTRPglAujoXHFdH5hu/MGvaXh+zZb5lteW1xHpdAHQmlvRmy2sr1uEc
RS53sKvqWwDem3c+IaRYKL0REgkMdkIigcFOSCQw2AmJBAY7IZFQeMHJHElZucSrQrPN3FMV/H6aWH3DvMKRto/irb4jy9VqYTmsLbFlsl3d9uVYTmxb57Zu0zY3F+4RNzVzwZzz+hlbXpuv2IUj28rtpq0dVdP2jkNhWbHqFAJ97cw522jAOzshkcBgJyQSGOyERAKDnZBIYLATEgnF78aTppIiXHPN21X3lRDbWnISV9qNhJdD+3aZc27ab9eF++3rL5m2snPLmplbCI5XX3/TnHO5Yu+ql8SuT9ddspWGdx4+aNr27d8XHH/1rVFzjrVT772WvLMTEgkMdkIigcFOSCQw2AmJBAY7IZHAYCckEii9bQJLvCq6wZOP0V7JcTKBndxRFtvW6Vw9B3fvDI7f/o6bzDntWDZtYzW7tVJSs5NMFhfC0lulNmHOKW/ba9p6urpM27sHwxIaAAwesI95diqcrDM6bvuY5kgP452dkEhgsBMSCQx2QiKBwU5IJDDYCYkEBjshkbCu9CYijwL4BIBJVX1XNtYH4AcABgGMALhXVS9u6IyGYiANbvGUJAW+jzn17rxaeHmfsZvZZBhLTg26jpJd7Kyv284AO7jXzmC7ZaA/OL5zu90i6cLEtGlLq3bbqJLzUksaluWqi/PmnB09e0zbgb1hSREA+ndsM22zF+26diO/PR8cX6rYkiJKdvadxUYi4jsA7n7b2IMAnlXVIwCezX4mhFzDrBvsWb/1t78t3QPgsezxYwA+2WC/CCENJu9n3X2qOp49Po96R1dCyDXMlv+w1fofpeYfhCJyTESGRWR4eeHyVk9HCMlJ3mCfEJEBAMj+n7R+UVWPq+qQqg51dvfkPB0hZKvkDfanANyXPb4PwE8a4w4hpFlsRHr7PoC7AOwRkVEAXwbwFQBPiMj9AM4AuHejJ7QKH+bJHGu0XJcXz49m+OgdMTFW0ms/1N9tyzg39Nly0v5eW5bb1ROW2ERtmc9TS3fvDkt5ALC0ZMtoy8vhTLqFRbtY5q4O+2rsEbuo5PJSOMMOACYu2n/CTs0bGX1OO6mycV15hUXXDXZV/Yxh+sh6cwkh1w78Bh0hkcBgJyQSGOyERAKDnZBIYLATEgktKDhpyRp+x7HwjJwZZV4mWi6lrFh5zTN2t4Vf0r07us05Rw71mbaF6TOm7cT/DJu2nrv+KDi+a5fdz6293c6I6+u3v5G9MGvfs3Z0d4bnLNpS5MLlOdM2O2GvR5oeMG3Tl+3CnasSltgSJ51PrCKhzrXBOzshkcBgJyQSGOyERAKDnZBIYLATEgkMdkIioVDpTQAkRu8wcXqKQcLvSV6/K0+26Cnb0ltJ7aymjvawRJJ675mG7wDQachkgJ8d1lY2+rkB6N+xPTi+p9fOXtu3x5blRpwyotPTU6bt7NnfBMd7e28357S12c+ra5vt4/YuOzusZ1s4oy91LrezZ8ZN24kTJ03b6MnTpu3g7b9v2spJWB5cqdlO5skS5Z2dkEhgsBMSCQx2QiKBwU5IJDDYCYmEYhNhBFCj0Jif1BK2idq7lZ2Jbdtesnfc9/XY9dgOHBwIjiftXeactjY7uaPD2Y2H0bYIAMRJ5Gk3BIqKUx9tdsosDoxa1fajzUlcGTnzRnD80E03mHN2ONWHdZu9xpLYyoWUwwvSZowDwJ5+u97dDTfsN23zmLX9UHsdEw23thInPKvGfdptDebYCCHXEQx2QiKBwU5IJDDYCYkEBjshkcBgJyQSNtL+6VEAnwAwqarvysYeBvA5AFcyIR5S1afXO1aKBJVSOCGjjHCbHgAo1cJS2a5ttvtL518zbRNzF0zb4aE7TNvu7WGpqa3DkdcceUoSO/FDjOQIAEjEft5lQ46sdNiSYmXF9nF2dtq0efX6FpbCLY2mp+zkmY6y7WNas9s1wVbesFoNi1Fp6iRRddhJQ++6407bjZ12As3ZC3b7p7SUQwHPkQmzkTv7dwDcHRj/hqoezf6tG+iE
kNaybrCr6nMAZgrwhRDSRLbyN/sDInJSRB4VkV0N84gQ0hTyBvu3ANwC4CiAcQBfs35RRI6JyLCIDFcW7Na6hJDmkivYVXVCVWuqmgL4NgBz10JVj6vqkKoOdXSHq6gQQppPrmAXkbUZIZ8C8Epj3CGENIuNSG/fB3AXgD0iMgrgywDuEpGjqAsAIwA+v7HTqZn905Ha0ts7B8Otf27abWdCXeq09xS3dR40bR1dtuwyff58cLzdkd66Om0Jratnh2krtdt+eNlmMLIKy07durZ2u77btk7bj94dtv8rRpG3iYkJc07ZkSJ11c4a87g0F5a8Li/ZmY8rzqmWV22d79xFO7Ow3G1vayXW83YkRTMT1J6yfrCr6mcCw4+sN48Qcm3Bb9AREgkMdkIigcFOSCQw2AmJBAY7IZFQaMHJktbQXQ1/i+62G/vMeR9496Hg+KWxcIshAFgUOy2ow2kltKp25tXyUjjzapeTJdXu2Lq6bD/EyYSqORlgC4aPWrN1nI5OW/IqOZLd9u32l6Rm5sLFF88Z8iUAbOuwZcrKgi1rjZ2z5bxXXw9fI5WafZ+7+bZ3m7Zyd69pa9+xx7Spk6lYNV4aL7HNyzi04J2dkEhgsBMSCQx2QiKBwU5IJDDYCYkEBjshkVCo9CYCdJbDgkL/TkfGmTwXHH/xxAvmnNExO+vtyHvsDKTdN4Qz7ACgqxSWoZJOp1Bip9O/zOkDB6PIZv2gtvRm9YFLyvb7uohjc/OobNvSUjiLcXUl3NcMACYnbFnutVOvmrZzZ+15oxOXguMzi/Ya7r3Vlt56+2zpzUmIg9OeD0jDxnq5CON4lvbmnIh3dkIigcFOSCQw2AmJBAY7IZHAYCckEgrdjVckWJHwDvSpEXtHVZcuBsenp+yd3ZXSTtM2MmfvIp9P50xbdzm8O9rRbi9jb6+9qz7QZ+/U97aFa7gBwLaSveOqaVgxkMRplbUYbtUEAGnq7Ag7O79Li4vB8faOdnPOpVl77cdGx0zb/IJdv7CyEl7Hnbtt1aXcZV87FSdkqs5aJU5ai7WOatTxA2AKIeqch3d2QiKBwU5IJDDYCYkEBjshkcBgJyQSGOyERMJG2j/dCOC7APahXhbruKp+U0T6APwAwCDqLaDuVdWwRpahAFbS8Ckn52z5JEnDck1p92FzTqfYySnzVbuu2vxsWDICgBkNJ0+IU++ubSbcfggAxsbt99rbb9xt2gb32klDaiTrrFqFzgAsXrZ9BOy1OjcVrjMHADOXwz2Ujg4eMecc2r/XtN18aNC0LVRsKfXUm+PB8WrZrnfX3WsnSi07yS6OyS0ol1qJMMZ4HVN7M9nInb0K4EuqehuA9wP4gojcBuBBAM+q6hEAz2Y/E0KuUdYNdlUdV9UT2eN5AKcBHABwD4DHsl97DMAnm+UkIWTrbOpvdhEZBHAHgOcB7FPVK5+RzqP+MZ8Qco2y4WAXkR4ATwL4oqpe9b1GrX/fL/jXgogcE5FhERleXvD+NiSENJMNBbuItKEe6N9T1R9lwxMiMpDZBwBMhuaq6nFVHVLVoc5u+7vghJDmsm6wi4ig3o/9tKp+fY3pKQD3ZY/vA/CTxrtHCGkUG8l6+wCAzwJ4WURezMYeAvAVAE+IyP0AzgC4dyMnLFkZPmJLPGkp3EIpdbKu1K2dZs8T8ep+hX2sOeeqOKXkVip2ttlNA/Y8LdmyohiSY+rVtHPkmqqRRQcAS1Wnhl5HWDrcf+AWc8qth+0nvVKxpdnZFXs9FjpHg+MX5sJtyAC/9pujsiLx5DXvWjVsqTrXsFGDzrvu1w12Vf057MqCH1lvPiHk2oDfoCMkEhjshEQCg52QSGCwExIJDHZCIqHY9k+wt/U9ycAULawWOOviSG/urM3LHb6Pznut05KplHjtmsLPrZqGs9AAoOb4cXHRntfRa39Deu+OcLHEbd07zDlSDkusAFCt2MUXz1+wky0tiW215uWoeUU2815zNtbV6OW8+VdqGN7Z
CYkEBjshkcBgJyQSGOyERAKDnZBIYLATEgmFSm+/G+SQVnKqMV6vNI9SYmeiSTVcFDOt2llvFbUvg8l5Z57YfdvakrBkt5rai1Vz+tEtVp3+fNOXTFvVkNjUuc/VajmzKV2ZNd9r3Uh4ZyckEhjshEQCg52QSGCwExIJDHZCIqHQ3fh6venW70o2nBxderZCrWYnp6TLC+FxJ4FjbsV+AlOXw7v7AJAm9m78qlHHbblqn6ua2C2Zxi+G2zgBwNSc3bIrtRKKnJ3/mp1zg2a8oJInoSuHksM7OyGRwGAnJBIY7IREAoOdkEhgsBMSCQx2QiJhXelNRG4E8F3UWzIrgOOq+k0ReRjA5wBMZb/6kKo+ve4Zr3HlrWAVzWR11Za8KpWKadPVsCxXUVvWOj8za9qWVmyZT5yEHK2FbTPzYWkQAE69MWLafn3mnGlbcS7jpGS07MrRWgkAUucqSJy2UZ5UZiZE5W5vFmYjOnsVwJdU9YSIbAfwgog8k9m+oap/v+mzEkIKZyO93sYBjGeP50XkNIADzXaMENJYNvU3u4gMArgDwPPZ0AMiclJEHhWRXQ32jRDSQDYc7CLSA+BJAF9U1TkA3wJwC4CjqN/5v2bMOyYiwyIyvLxgt8klhDSXDQW71Jt+Pwnge6r6IwBQ1QlVrWm9mfW3AdwZmquqx1V1SFWHOru3N8pvQsgmWTfYpf4t/UcAnFbVr68ZH1jza58C8Erj3SOENIqN7MZ/AMBnAbwsIi9mYw8B+IyIHEVdrRoB8PmmeHhd49Rjc1KvKiu2LJcgLDXNLtkS2vmZOdPm+Sipkx5mtKgan7xgThmfmDJty6lTd69kZ99Z7otTCy8pOdJb6shhbkcp25gatpqb2WbV1rPnbGQ3/ucIL9n6mjoh5JqB36AjJBIY7IREAoOdkEhgsBMSCQx2QiLhum3/lKuIH5qR2WZLIQJbukq8/DvnudUk/JJOz9otkpYdKQ9OZhs8Gcoo9FhxJMVE7HOljk2cDDZzqQxpsH480wRxXhe3nZfjoxo2v6bk5q9U3tkJiQQGOyGRwGAnJBIY7IREAoOdkEhgsBMSCYVLb0X1essrveU6l2Nrc55vu/NW297mvDRlO8vr8nI4G2ri4kVzjhgZVAD8nmiuHGY9b0c2tL2AeNl3zjxb1spXHFKc7DVPXkvVebHNgpP2HLPgpOMD7+yERAKDnZBIYLATEgkMdkIigcFOSCQw2AmJhGKlNwXSHH2tkgJltDwkYssxHWXb9za1paupi3aN/RWn19tyNbyO8wvL5hw4EponbPl98cJWLfi1tDLRrCKP3py60ZvnZbbZx7SSB103zNV34sg+HCHkeoLBTkgkMNgJiQQGOyGRwGAnJBLW3Y0XkU4AzwHoyH7/h6r6ZRE5DOBxALsBvADgs6q64h9N7a3Ha3vD3SVRO4VDV+22S0tV2zY6bS/luFM/zVIuVp16cTV3h9zfc78W8Ha6LZu/O56vzpxvM022j17STQ5VYyN39gqAD6vqe1Fvz3y3iLwfwFcBfENVbwVwEcD9mz47IaQw1g12rXM5+7Et+6cAPgzgh9n4YwA+2RQPCSENYaP92UtZB9dJAM8AeBPAJVW98jl0FMCB5rhICGkEGwp2Va2p6lEABwHcCeD3NnoCETkmIsMiMry8cHn9CYSQprCp3XhVvQTgZwD+AMBOkf/vSHAQwJgx57iqDqnqUGd3z5acJYTkZ91gF5F+EdmZPd4G4KMATqMe9H+S/dp9AH7SLCcJIVtnI4kwAwAeE5ES6m8OT6jqf4jIKQCPi8jfAPhfAI9s5ISWzGDXLANSQ2ryEmTcZAaHPLXrvHpmtdSR5ZwElGpi15mreu/RqSXnefXRvGQXpx5bMeUEm4IrrzmSV+pImI4JNa9VVo5rNc+cdYNdVU8CuCMw/hbqf78TQn4H4DfoCIkEBjshkcBgJyQSGOyERAKDnZBIkLwSVa6TiUwBOJP9uAfAdGEnt6EfV0M/ruZ3
zY+bVLU/ZCg02K86sciwqg615OT0g35E6Ac/xhMSCQx2QiKhlcF+vIXnXgv9uBr6cTXXjR8t+5udEFIs/BhPSCS0JNhF5G4ReU1E3hCRB1vhQ+bHiIi8LCIvishwged9VEQmReSVNWN9IvKMiPw6+39Xi/x4WETGsjV5UUQ+XoAfN4rIz0TklIj8SkT+LBsvdE0cPwpdExHpFJFfiMhLmR9/nY0fFpHns7j5gYjYqZEhVLXQfwBKqJe1uhlAO4CXANxWtB+ZLyMA9rTgvB8C8D4Ar6wZ+1sAD2aPHwTw1Rb58TCAPy94PQYAvC97vB3A6wBuK3pNHD8KXRPU85F7ssdtAJ4H8H4ATwD4dDb+jwD+dDPHbcWd/U4Ab6jqW1ovPf04gHta4EfLUNXnAMy8bfge1At3AgUV8DT8KBxVHVfVE9njedSLoxxAwWvi+FEoWqfhRV5bEewHAJxd83Mri1UqgJ+KyAsicqxFPlxhn6qOZ4/PA9jXQl8eEJGT2cf8pv85sRYRGUS9fsLzaOGavM0PoOA1aUaR19g36D6oqu8D8McAviAiH2q1Q0D9nR1+d4Zm8i0At6DeI2AcwNeKOrGI9AB4EsAXVXVura3INQn4Ufia6BaKvFq0ItjHANy45mezWGWzUdWx7P9JAD9GayvvTIjIAABk/0+2wglVncgutBTAt1HQmohIG+oB9j1V/VE2XPiahPxo1Zpk5950kVeLVgT7LwEcyXYW2wF8GsBTRTshIt0isv3KYwAfA/CKP6upPIV64U6ghQU8rwRXxqdQwJpIvfDfIwBOq+rX15gKXRPLj6LXpGlFXovaYXzbbuPHUd/pfBPAX7bIh5tRVwJeAvCrIv0A8H3UPw6uov631/2o98x7FsCvAfw3gL4W+fEvAF4GcBL1YBsowI8Pov4R/SSAF7N/Hy96TRw/Cl0TAO9BvYjrSdTfWP5qzTX7CwBvAPg3AB2bOS6/QUdIJMS+QUdINDDYCYkEBjshkcBgJyQSGOyERAKDnZBIYLATEgkMdkIi4f8Ao6yD4fVUn9YAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.eval()\n",
    "\n",
    "img_idx = 10\n",
    "rgb_img = get_rgb_img(img_idx, testset)\n",
    "imshow(rgb_img)\n",
    "\n",
    "# Inference only: no_grad skips building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    output = model(testset[img_idx][0].unsqueeze(0))  # add a batch dimension\n",
    "score = output.softmax(1).numpy()\n",
    "\n",
    "# Indices of the three most probable classes, highest first.\n",
    "top3_idx = np.argsort(-score[0])[:3]\n",
    "for i in top3_idx:\n",
    "    print('Category: %s, Probability: %.2f' % (categories[i], score[0][i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Categoriy: frog, Possibility: 0.88\n",
      "Categoriy: cat, Possibility: 0.10\n",
      "Categoriy: dog, Possibility: 0.02\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAfPElEQVR4nO2dW4ylV5Xf/+vcL3Xva/XNbts9nhgGDNOxmAwiZEYzctBIBilC8ID8gKZH0SAFafJgESkQKQ9MFEA8EZlgjSciXDKAsCKUQKzREBRhaHtMu7Fxu213t7u7ui5dVV2nqs79rDyc05O2s/+7yl1Vpxr2/ye1umqv2t+3zj7fOt85+3/WWubuEEL85pPZbQeEEMNBwS5EIijYhUgEBbsQiaBgFyIRFOxCJEJuK5PN7GEAXwaQBfCf3f3zsb/PZ82LWQs7UsjTeblc2E3v9uicTrtJbRnjr3GFPF8SJlP2et23PQcAMhnuR8wWXsGbxrA1m83yc2X5uWLKbK/H1z9Dz8e9jx0vtsbsMcfO1nN+Lu/d7nMW8SPypLU74cfWaHb4JHINN1pdtDrd4NnsdnV2M8sCOAfgjwBcBvBzAB939xfZnJFCxn/nYDio9x89TM+1Z8/e4HhzZZXOuX7lArWNlErUdnj/fmrrtVvB8draMp3T7rWprVIqU9tImdvykcC1XHh9R8fG6JzyaIXa2pEAXKvzF9SR6kRw3CLBslZfo7bVdW7L5XgkZcn5Gs0GndOK2MplvlblCr+uCpHb6sz89eD4udfD4wDgudHg+M/OXcPKeiu4IFt5G/8QgPPu/pq7twB8E8AjWzieEGIH2UqwHwbwxi2/Xx6MCSHuQLb0mX0zmNkpAKcAoMDffQohdpit3NmvADh6y+9HBmNvwt0fd/eT7n4yH9nAEELsLFsJ9p8DOGFmx82sAOBjAJ7aHreEENvNbb+Nd/eOmX0KwP9EX3p7wt1/GZuTzRgmSuFTTk/x3eJGqx4cL+S5kjAyxndN87nIbnYpIg0R2bBaHKFzCoUCtY2MTnI/iBwDAE7WAwDyhWJwfGJvWNHoT+Ky59hoeNcXADI3uBpSKITXJB+R+TJFvpudifgYkxVbRJ7NOPejUuLXYjHyfFbL3MfYbXW9WwuOL6zza3H5xlJwvNni182WPrO7+w8A/GArxxBCDAd9g06IRFCwC5EICnYhEkHBLkQiKNiFSIQd/wbdrWTMUc2GE0O8tkjn7Z8IS1RrNZ5ksljnslBxpEpt9ci88cmwJJPpccnFslxOaoDPK1a5dNhy/ribnXDCSHd5ls7pOZeuRuo8KcQy/PLpengdGx3uey6Scbjv4AFqq6+vU5uvhddjD7mmAKBS4WtfLIalTQBorIclNAC4shCWygDgRiO8/pMHjgbHAeDitXPB8W4sY49ahBC/USjYhUgEBbsQiaBgFyIRFOxCJMJQd+OL+SzuPrQnaHOL1ARjO/XLfPdzOlYPLJIs4Cs8yWRlPbyT3CUJMgBwo8lLVtUiu/F7p3l5rO4aVwymx8KJGi2yKw0A7TZf+5V5rpIgVo+NFS+IZDkXyjzpZqLLE3nWV/lja9DHHSnHNsl36ldv8LW6cnWG2p4/9/9lf/8Da5nx4Ljn+M5/pxOuT+eRx6U7uxCJoGAXIhEU7EIkgoJdiERQsAuRCAp2IRJhqNJbqVzBA+94V9DWXePJDKurN4LjzSqXrvKRhIv5df4a9+o8l97WiTzYy/BEkvm1cBcZACiM84Sc2lUu2eWbvBPLA4dPBMcPjPIOM60WT3ap1fnzsh7pFtNoh9e/2eDr21jjfsy3eCJPqcjXsZgP18LrRpKJFpZ40srcHO/SMnONS8FX5/lz5uXw426u83NNVcPXXLPGZWXd2YVIBAW7EImgYBciERTsQiSCgl2IRFCwC5EIW5LezOwCgBqALoCOu5+M/b0DaJFMqbGJCTqvWgpncnW7
PEuqEcmuev7MJWo7Ox+W+QDAqmE/chkuAY6OTnFH8jyr6fp1LrtYg8t5//u588Hx++7iWWP3HOFZXvv2H6G2kZFpausQNXJtnWfs9Zw/rrl5nlG2FqtBlw1ngbV6XKK6vsyz6K7M8utjYYn70YpkWq6uhZ/rcpFLuvlC+JqLtU7dDp39n7n7wjYcRwixg+htvBCJsNVgdwA/NLNnzezUdjgkhNgZtvo2/v3ufsXM9gP4kZn9yt1/fOsfDF4ETgHA1CivoS6E2Fm2dGd39yuD/+cAfA/AQ4G/edzdT7r7yZEy720thNhZbjvYzaxqZqM3fwbwxwDObpdjQojtZStv4w8A+J6Z3TzOf3X3/xGb0Gy38erM1aDtyASXqO6aCGc1ZXv8tep6gxfeW1zmmVeZPM+gqlTDthtLPNupCy7HjGb58h/cz9sdzV2bp7ZXroWz5a5F2mvNrfC1/8Dv/hNqm47Iih2Es8rKRS5TrqzybLMSkZoAYHWVS3Y1IvWtRuS6hRtceluNyJ7NSBZgqcyv1f3TYdk5l+HX8NpK+HFZRHu77WB399cAvPt25wshhoukNyESQcEuRCIo2IVIBAW7EImgYBciEYZacNIBdHphOeH0mZf4xN86FBy+/xjP5OrVeIG/eiPcJwsAVlZ5IcJ8JZylViry7LW1dS7z9Tq8wGKrUqG28fExasuVwhJVM1Kk8uUZnolWepE/L7ksl5oynbB81XOe/dVucVmrV+dyWDXLn7Ouhdd4vcefl3aDr0enxc9VjhRALRmX0e45Fs46zPGkN8zOhq/hy5G+d7qzC5EICnYhEkHBLkQiKNiFSAQFuxCJMNTd+Gwmi4mx8M7j+gTf5Tz7arj+2MEpnrRSKvLXsdEJnlc/AW5rk5ZBk+O8Fl5sF7xV56pAPZIUMrmHZztM7QknVSzdWKFzYm2X/v7l16it3eI7v+86vi84XoldcZEadLkc3/nvRnb4c3lyHWT4GvIzAbkS3yJfbfEd906XH5VdI8UKTwkvF8ILmYlkwujOLkQiKNiFSAQFuxCJoGAXIhEU7EIkgoJdiEQYqvQGB3rtsARx+Ph9dNr5l18Njv/k2Qt0zu/+znFqm5rgUtlSj8s/yythicR7XDacmOTnWqvxeZFyZoBH6uuRtlFjY+N0TjHHL4O5WV7v7vTL4XqCALC4FG6TdN9h3mrqwCSXmizDF6Te5uu42gw/n6t1/jw3OvxcPdK+DADyRe7/eJlLuuvrYf9bdZ6s40TKi1waurMLkQoKdiESQcEuRCIo2IVIBAW7EImgYBciETaU3szsCQB/AmDO3d85GJsC8C0AdwO4AOCj7s7TtAY0mw28eu6VoG3i2D18YiUsX126tECntH95hdpqHZ4ZtLLCWzlNjoaz7JotnjU2ObWf2kZHucSzvMyz1EqRBpm1tbBcc3XmGp2z7wD3cd/BcPYaACwTeQ0Azs+F67gtR9onHd3L6+4dmOQ2y3DprdkKZ8SxWogAUI5IaJUq96M4wuVNi2SjrSyHQ6e2xrMKC6QdloM/rs3c2f8KwMNvGXsMwNPufgLA04PfhRB3MBsG+6Df+lu7Aj4C4MnBz08C+PA2+yWE2GZu9zP7AXe/WVHiGvodXYUQdzBb/rqsu7sZL4ptZqcAnAKASkH7gULsFrcbfbNmNg0Ag//n2B+6++PuftLdT5ZYiSAhxI5zu9H3FIBHBz8/CuD72+OOEGKn2Iz09g0AHwSw18wuA/gsgM8D+LaZfRLARQAf3czJRqsV/NPfezBoe+ZXXCo7cPTe4Pjdh7lk9H/+7hlqm1vh2URW4rKLW/i1cb3OCx625t+6t/n/GB/nMk4m0lopm+UyTo5IMs0Ob2l05QrPXquO8KKe+w9NU1udSGzXF3kW3cqlZWqbW+b+3zUdLrIJAE4y4iLJYXjXu99FbYcOHaO22YgUeXUmXDQVANZXw4+7WCnTORlyLcLW6ZwNg93dP05Mf7jRXCHEnYM+RAuRCAp2IRJB
wS5EIijYhUgEBbsQiTDUgpP5fA7TB/YEbVOXw4USAaBTD0saR07wTLl3vvud1PbDv/s5tVVLXGpabYcFm7UGl96szqWQbqRQZT5SBLJe5/3jmq1w/7hylcs4vU6k51yDn+vaDJfsDk6Hv0Gdi2TYLS7wLMbFVe7jaERKHSmGZcqDB7lsuHcqfI0CwPG7ufR2/F7+nP3q3Ai1zc+GZbl2nV8fuRyXiBm6swuRCAp2IRJBwS5EIijYhUgEBbsQiaBgFyIRhiq9dR2okUJ/ExO8F9bsYjhzrN3gBXLuv+8wtV28zDOQXp/l2VXFsXBGWaaUpXMy3fAcAOhFXmtL5YgEWOM+eo9IbwUu1eQjmW095xl29QaXFefmwjLa9KFDdM59J+6ntvMvv0xtbyzyIpZH9oclr6MTXALsRh7z8gp/zJN7DlLboUNHqO3wdNiXtVd4FuBIMZwxmYkUttSdXYhEULALkQgKdiESQcEuRCIo2IVIhKHuxjc7XbxGdrurk7zNUKERTghYrPH2OMUSr++GDN+xvHIp0ibp8N7g+NTkJJ1TX+fJDNnIS22nw2vQjUR2z9vd8G58ocQTYZpN7mMmslZ79vDnbHExnNh08cIlOufYXUepbWR8jNrmrvG6dr0eud4qPOkmdu28+Mo5aitc5HUU7zl+F7VN7Q3vxpcv8ePlSSJMrM2U7uxCJIKCXYhEULALkQgKdiESQcEuRCIo2IVIhM20f3oCwJ8AmHP3dw7GPgfgTwHc1Dw+4+4/2OhY9UYLz798OWgbK3OpaW01LJ+0r/JEgXqT14Wr1WrUVi7wpJb5K0thA29ii7GJUWor5Pi5Oi1e+y3DpyFn4ac0JsmUKzwJKZvliTzFSLKOe/j5bLe5THbh9YvUFpPliHIFAFgkstyFN7gfk1NcSs3m+HXaWiLXBwCLPGeXL88Gx2cjSVnVavj6jkm2m7mz/xWAhwPjX3L3Bwf/Ngx0IcTusmGwu/uPAfDuhEKIXwu28pn9U2Z2xsyeMDP+vkcIcUdwu8H+FQD3AngQwAyAL7A/NLNTZnbazE6zmuZCiJ3ntoLd3Wfdvev9XZivAngo8rePu/tJdz9ZLAz1q/hCiFu4rWA3s1vbaXwEwNntcUcIsVNsRnr7BoAPAthrZpcBfBbAB83sQQAO4AKAP9vU2SwD5MMyT6PD5bCZ2bCksbDA64Edj9SgO3GMZ2tluOKFF14OSyTdSPukSpnXfisWuKy1sBZueQUApUJE8iIy4MoKP97Y+AS1dXpcyvF6g9rGx8aD490ul0RnrvKMw7lZLpUdOcJludqNcGbktXm+Hj/9Kb93lce4hDlS4dlyly5x/5mPxcjxHnjHieD4K/Pc9w2D3d0/Hhj+2kbzhBB3FvoGnRCJoGAXIhEU7EIkgoJdiERQsAuRCEP+lovDPCy9WORlp5wPy1eVHJdxfuvYNLWVc7xdULbFZajF5bA8mC1xCY0LNUChWOR+5PlT0+nwx53Jhue1Wvwxz86EJUUA8MgjyOT5454kRTirETlpbJxnCN4gaw8A1xa4/2NT4WNer/P1aES+6Jnt8Md8dYZnYXrkmN4LP5/T0/xa7FlYEo10rtKdXYhUULALkQgKdiESQcEuRCIo2IVIBAW7EIkwVOkta4axfFgbqPEENuw7cCA47u2rdI43uVSTy/BMtMkxLq184B+Hs6t+9QaXXNodLvG0PNJjLSJrtVb5vGanHhzP5/jxem3uY7vFZb4CeBXFtVo4q2x9nRdRzEaa31VGuGR3Y3mF2vbsmwqOj+0ZoXNqS/z5zLf4tdPu8PXIZHn24OS+cIZgqcLD88rlN8I+RCRW3dmFSAQFuxCJoGAXIhEU7EIkgoJdiEQY6m58p93G4uxM0NbK8N3iRibsZifypf96k9dHK2T5DrP3+LyJ0XD9vMkqX8ar6+HdcQCorfLX2ixJ/gGA9Q6XLjqd8E69R9o/NRq88J4795ElcACAOUnIiZyrF7n3xPwv5PharZKd
+mqV7+6XytwWWyuL9HhaX+fzctmwQnF0ktdK7LXDdesy4K3IdGcXIhEU7EIkgoJdiERQsAuRCAp2IRJBwS5EImym/dNRAH8N4AD67Z4ed/cvm9kUgG8BuBv9FlAfdfdwn6YB3V4Hq6vhP+nleWJCpToWHh/lc0b2hJNnAGDvSESqWeItiCqkXdOBKe7HDSKRAIBFatD1jD81jQqX3nrr4YSLZqSwWibDJSN3LuW025FaeI1wQkarzf3I5vhj7kWKFBbz3P9GPSxF1tf4Gk5O7aG2hXnexqkTSRoan+JdzYv58HOW4bkzOHFXOCnrp69yqXczd/YOgL9w9wcAvA/An5vZAwAeA/C0u58A8PTgdyHEHcqGwe7uM+7+3ODnGoCXABwG8AiAJwd/9iSAD++Uk0KIrfO2PrOb2d0A3gPgGQAH3P3m1+Guof82Xwhxh7LpYDezEQDfAfBpd3/TdxC9/8Eu+OHOzE6Z2WkzO93u8M9/QoidZVPBbmZ59AP96+7+3cHwrJlND+zTAOZCc939cXc/6e4n87lYywQhxE6yYbCbmaHfj/0ld//iLaanADw6+PlRAN/ffveEENvFZrLefh/AJwC8YGbPD8Y+A+DzAL5tZp8EcBHARzc6UCaTQalaDdpmF3m2mSGsQUxMhI8FAAvXuQRx176D1DYaabs0NRLOerPcdTrn8tJFamtEWiu1e1x3sSL3sZIph+eAZ10VClyKbETaJEVUObSb4fPlIrXwotl3kaw3RI5pvbCTsey71XVev7A8yq+52gK/DjIFLs9aKSzBNiPXR2Uk3NYqk+Uy5IbB7u4/AW9Z9ocbzRdC3BnoG3RCJIKCXYhEULALkQgKdiESQcEuRCIMteCkWRbZfDiD7eABLjM0G+GCfIVCWAoDgKXlcPshALg4Ey56CQDvOLSX2sarpE1PiWe29Xq8VVMvkgFWLHGpphlpX9VDWCorlMOSHAA0IkUxy2V+iVTLkaw9Ist1O/wxlzpcQutEWk01mzzbrJgN+18n0iAA1GtcBt4/vZ/aStP83rkwv0BtU/vCx6y3uPzqpK2YO5+jO7sQiaBgFyIRFOxCJIKCXYhEULALkQgKdiESYajSG8BfXQ4dOUTnFEphqWlleZnO6YFnLi2RopcAgAyXVlji1cwiz3ZqdLgUko1k2K2u8oKIMfmq3Q1rXr3InHyBv+aPkLUHgPo6lxxRCPdLGx8Ny5cAUKuF+7IBQLsRqb4Yy4gjil2sskK3xeXS9Rv8MY9O8oy4I8fCBSIB4PDhY8Hx1sJlOidPJF2LpCLqzi5EIijYhUgEBbsQiaBgFyIRFOxCJMJQd+ML+RyOHAq31snkeDLD1Fh4l7OQiSTPNHmSxsoaT5JZqPFd/GNHwkkyK2s8caLb5X4UIu2OJio8KaTV5q/R3gtvP6+u8J3u8UrEx0hrpfVIosbk/vDucyFSI215jde7y2a5rVTkflQr4Wsnl+VruLTIVZ5GnasksR3+YpU/11fn3ggblrnK05kibRpYBhJ0ZxciGRTsQiSCgl2IRFCwC5EICnYhEkHBLkQibCi9mdlRAH+NfktmB/C4u3/ZzD4H4E8BzA/+9DPu/oPoyXJZ7JuaCNrcuCRjRJUrRubkIzXXVupcKnvu7Kv8mKSe2dJKRBYqh2vuAUCpEk4WAYDVSJJJLvK4nfgYS2jJ9LjsCeMS4IHpw9TWaIVrvC3W+Fot1yKJNS0+L5/j69Ephv1fW+MSmpN2YwDQJrXfAKCU4c8n6lz69GZY6jswwq/h8YlwQlE2shab0dk7AP7C3Z8zs1EAz5rZjwa2L7n7f9zEMYQQu8xmer3NAJgZ/Fwzs5cA8Jd0IcQdydv6zG5mdwN4D4BnBkOfMrMzZvaEmU1us29CiG1k08FuZiMAvgPg0+6+AuArAO4F8CD6d/4vkHmnzOy0mZ1ei7T/FULsLJsKdjPLox/oX3f37wKAu8+6e9f7Vem/CuCh0Fx3f9zdT7r7yWqZbxIJIXaW
DYPdzAzA1wC85O5fvGV8+pY/+wiAs9vvnhBiu9jMbvzvA/gEgBfM7PnB2GcAfNzMHkRfjrsA4M82OpDBkLPw60u5EpGoiHy1FqlZVlvnGUPZAq8Vdm3xKrWdey18zLUmlzuq46PUVqxyaWV+gftfW47UQZsIr9VElb+raq/zdZxf5Odqg8tyrEZarsjXPh/JRCuO8nZYcC4drpIsxlaLy6/FAn8+M5GsvUaXt9G6d4rLcg8eDtc9HB/na1UeC2cqWmQNN7Mb/xOEs/eimroQ4s5C36ATIhEU7EIkgoJdiERQsAuRCAp2IRJhqAUn250O5hYWg7Y9U9yVXLYUHK+3uCy0GCkMuDjHv8m3/yD/1u++Y+FimbNrXLoaGQn7DiDatqjb5nLS3NwCt82H13fvHt526cA+Lg/mI9+DWr0Ra79FHhtpTwUA5Qy3FSvckXyBS4DFetjW6vAWT2OjXPJy3kULjSa/Hscq/Lo6cXQ6ON5qxjLzwsSKXurOLkQiKNiFSAQFuxCJoGAXIhEU7EIkgoJdiEQYqvTW6znWG2Htotzmclh9YSY4fqPGJahMjssxzSUuGR37R/uo7eC+cEbZ+fORIn8ROWl5ZYna2j2elXXwGOnzBcA8LFGt3eDy4MULs9Q2Oc77wFUjxSg7HhaBmpECJq2ILaJ4RQt3tonENjLB51QjxUrri7xP4GSkcGes4OfZq+Feb9VIL8Bj5SPBcYvcv3VnFyIRFOxCJIKCXYhEULALkQgKdiESQcEuRCIMVXrL5/OYPngwaIsVylshWWXzi1y66va41JEtcjlpZZWa8NyZ14Pji5Feb3XjmVArkX5uk/u4vDYyEe6XBwCtZlhqWrjG86Haq3ztq0W+js1IZiHri9eI9ErrRGzFJr9Um4hkhxFVdKTCM9sybS573jPN1/7+feGsSABoOD/mSzPXguPlDM+YLBSnguPtDu9Tpzu7EImgYBciERTsQiSCgl2IRFCwC5EIG+7Gm1kJwI8BFAd//zfu/lkzOw7gmwD2AHgWwCfcPdqmtdtuYXH2YtB2Y43vVrY9vMPYc56Icb3GkxIsUkjM53ltMu+GfSzkeCJMfZkn6xQjCsRYlidqlCNthvLV8JrsOxrevQWAVS5qIOc8kadV5LYGWcZKpJZcN/K8FCJJIcVIIky9Ez5md51fqqXIZfzb9x+ltomIcvH6XLgNFQDkC+EagJ1IeF66EVZyWt2t7cY3AfyBu78b/fbMD5vZ+wD8JYAvuft9AJYAfHITxxJC7BIbBrv3uak+5wf/HMAfAPibwfiTAD68Ix4KIbaFzfZnzw46uM4B+BGAVwEsu//D+67LAA7vjItCiO1gU8Hu7l13fxDAEQAPAfjtzZ7AzE6Z2WkzO11vxUoQCCF2kre1G+/uywD+FsDvAZgws5s7CEcAXCFzHnf3k+5+slwY6rdzhRC3sGGwm9k+M5sY/FwG8EcAXkI/6P/F4M8eBfD9nXJSCLF1NnOrnQbwpJll0X9x+La7/3czexHAN83s3wP4ewBf2+hAmUwX5XxYguiWYjJOOIljYoS31HHwjJZjh7kMNVXiSTLdNaJRRaSwXJ4nXOQyXKqp5Lkf6HF5pUuqteVHR+iczmikx5NzCbMRaVF1cSmcvHR9lSf/+Ci/9ziR0ACgVY8csxH2sZThl/4D9xyjtrv38mtuaY3LaxN7eG3DqXL4uVmu82u4WA3Xycvk+RpuGOzufgbAewLjr6H/+V0I8WuAvkEnRCIo2IVIBAW7EImgYBciERTsQiSCeSSradtPZjYP4Gba214APCVseMiPNyM/3syvmx93uXtQ5xtqsL/pxGan3f3krpxcfsiPBP3Q23ghEkHBLkQi7GawP76L574V+fFm5Meb+Y3xY9c+swshhovexguRCLsS7Gb2sJm9bGbnzeyx3fBh4McFM3vBzJ43s9NDPO8TZjZnZmdvGZsysx+Z2SuD/3l61c768TkzuzJYk+fN7END8OOomf2tmb1o
Zr80s381GB/qmkT8GOqamFnJzH5mZr8Y+PHvBuPHzeyZQdx8y8wi6YoB3H2o/wBk0S9rdQ+AAoBfAHhg2H4MfLkAYO8unPcDAN4L4OwtY/8BwGODnx8D8Je75MfnAPzrIa/HNID3Dn4eBXAOwAPDXpOIH0NdEwAGYGTwcx7AMwDeB+DbAD42GP9PAP7l2znubtzZHwJw3t1f837p6W8CeGQX/Ng13P3HABbfMvwI+oU7gSEV8CR+DB13n3H35wY/19AvjnIYQ16TiB9Dxftse5HX3Qj2wwDeuOX33SxW6QB+aGbPmtmpXfLhJgfcfWbw8zUAvI3rzvMpMzszeJu/4x8nbsXM7ka/fsIz2MU1eYsfwJDXZCeKvKa+Qfd+d38vgH8O4M/N7AO77RDQf2VH/4VoN/gKgHvR7xEwA+ALwzqxmY0A+A6AT7v7m0rdDHNNAn4MfU18C0VeGbsR7FcA3NpWgxar3Gnc/crg/zkA38PuVt6ZNbNpABj8P7cbTrj77OBC6wH4Koa0JmaWRz/Avu7u3x0MD31NQn7s1poMzv22i7wydiPYfw7gxGBnsQDgYwCeGrYTZlY1s9GbPwP4YwBn47N2lKfQL9wJ7GIBz5vBNeAjGMKamJmhX8PwJXf/4i2moa4J82PYa7JjRV6HtcP4lt3GD6G/0/kqgH+zSz7cg74S8AsAvxymHwC+gf7bwTb6n70+iX7PvKcBvALgfwGY2iU//guAFwCcQT/Ypofgx/vRf4t+BsDzg38fGvaaRPwY6poAeBf6RVzPoP/C8m9vuWZ/BuA8gP8GoPh2jqtv0AmRCKlv0AmRDAp2IRJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJIKCXYhE+L+V4nNGH3PeMgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show one test image and the model's top-k predicted classes with their softmax probabilities.\n",
    "model.eval()  # switch layers such as dropout/batch-norm to inference behaviour\n",
    "\n",
    "img_idx = 30\n",
    "top_k = 3\n",
    "rgb_img = get_rgb_img(img_idx, testset)\n",
    "imshow(rgb_img)\n",
    "\n",
    "# Gradients are not needed for inference; disabling autograd avoids building a graph.\n",
    "with torch.no_grad():\n",
    "    # unsqueeze(0) adds a leading dimension of size 1 (a batch of one image)\n",
    "    output = model(testset[img_idx][0].unsqueeze(0))\n",
    "score = output.softmax(1).numpy()\n",
    "\n",
    "# Indices of the top_k highest-probability classes, in descending order of probability.\n",
    "topk_idx = np.argsort(-score[0])[:top_k]\n",
    "for i in topk_idx:\n",
    "    print('Category: %s, Probability: %.2f' % (categories[i], score[0][i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
